/*
 * Wazuh Vulnerability scanner - Scan Orchestrator
 * Copyright (C) 2015, Wazuh Inc.
 * March 25, 2023.
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General Public
 * License (version 2) as published by the FSF - Free Software
 * Foundation.
 */

#ifndef _SCAN_ORCHESTRATOR_HPP
#define _SCAN_ORCHESTRATOR_HPP

#include "factoryOrchestrator.hpp"
#include "flatbuffers/include/messageBuffer_generated.h"
#include "flatbuffers/include/rsync_generated.h"
#include "flatbuffers/include/syscollector_deltas_generated.h"
#include "indexerConnector.hpp"
#include "scanContext.hpp"
#include "wdbDataException.hpp"
#include <memory>
#include <string>
#include <variant>

// Filesystem locations (relative paths) used by the scan orchestrator.
/// Path handed to Utils::RocksDBWrapper for the inventory database.
constexpr const char* INVENTORY_DB_PATH = "queue/vd/inventory";
/// Path handed to the delayed event dispatcher (RocksDB-backed queue).
constexpr const char* DELAYED_QUEUE_PATH = "queue/vd/delayed";

// Numeric tuning values.
/// Bulk size for delayed events (not referenced in this header).
constexpr int DELAYED_EVENTS_BULK_SIZE = 1;
/// Default number of seconds an agent's delayed queue is postponed after a recoverable failure.
constexpr int DELAYED_POSTPONE_SECONDS = 60;

/**
 * @brief Threaded dispatcher keyed by rocksdb::Slice whose callback receives a whole
 *        std::queue of rocksdb::PinnableSlice elements per invocation.
 */
using EventDispatcher =
    TThreadEventDispatcher<rocksdb::Slice, rocksdb::PinnableSlice, std::function<void(std::queue<rocksdb::PinnableSlice>&)>>;

/**
 * @brief Threaded dispatcher whose callback receives one rocksdb::PinnableSlice at a time.
 *
 * Storage is a RocksDBQueueCF wrapped in a Utils::TSafeMultiQueue, i.e. one queue per key
 * (the orchestrator uses the agent ID as key).
 */
using EventDelayedDispatcher = TThreadEventDispatcher<
    rocksdb::Slice,
    rocksdb::PinnableSlice,
    std::function<void(rocksdb::PinnableSlice&)>,
    RocksDBQueueCF<rocksdb::Slice, rocksdb::PinnableSlice>,
    Utils::TSafeMultiQueue<rocksdb::Slice, rocksdb::PinnableSlice, RocksDBQueueCF<rocksdb::Slice, rocksdb::PinnableSlice>>>;

/**
 * @brief ScanOrchestrator class.
 *
 * @details Decodes incoming flatbuffer events (RSync, DBSync or JSON payloads), builds a scan context from them and
 * hands the context to the orchestration chain matching its ScannerType. Events for an agent that already has
 * entries in the delayed queue are appended to that queue instead of being processed immediately (clean-up and
 * re-scan-all events are exempt).
 *
 * @tparam TScanContext Scan context type built from each event.
 * @tparam TFactoryOrchestrator Factory used to build each orchestration chain.
 * @tparam TOrchestrationNode Type of the head node of each orchestration chain.
 * @tparam TIndexerConnector Indexer connector type.
 * @tparam TDatabaseFeedManager Database feed manager type.
 * @tparam TOSPrimitives OS primitives provider (used here for gethostname).
 * @tparam TSocketDBWrapper Socket DB wrapper type (not referenced directly in this class).
 * @tparam DelayedPostponeSeconds Seconds an agent's delayed queue is postponed after a recoverable failure.
 */
template<typename TScanContext = ScanContext,
         typename TFactoryOrchestrator = FactoryOrchestrator,
         typename TOrchestrationNode = AbstractHandler<std::shared_ptr<TScanContext>>,
         typename TIndexerConnector = IndexerConnector,
         typename TDatabaseFeedManager = DatabaseFeedManager,
         typename TOSPrimitives = OSPrimitives,
         typename TSocketDBWrapper = SocketDBWrapper,
         auto DelayedPostponeSeconds = DELAYED_POSTPONE_SECONDS>
class TScanOrchestrator final : public TOSPrimitives
{
public:
    /**
     * @brief Class constructor.
     *
     * @details Creates the inventory database wrapper (at INVENTORY_DB_PATH), builds one orchestration chain per
     * ScannerType, stores the local hostname as the manager name in GlobalData and starts the delayed event
     * dispatcher.
     *
     * @param indexerConnector Indexer connector.
     * @param databaseFeedManager Database feed manager.
     * @param reportDispatcher Report dispatcher queue to send vulnerability reports.
     * @param mutex Mutex to protect the access to the internal databases.
     */
    // LCOV_EXCL_START
    explicit TScanOrchestrator(std::shared_ptr<TIndexerConnector> indexerConnector,
                               std::shared_ptr<TDatabaseFeedManager> databaseFeedManager,
                               std::shared_ptr<ReportDispatcher> reportDispatcher,
                               std::shared_mutex& mutex)
        : m_mutex {mutex}
    {
        m_inventoryDatabase = std::make_unique<Utils::RocksDBWrapper>(INVENTORY_DB_PATH);
        auto& inventoryDatabase = *m_inventoryDatabase;

        m_osOrchestration = TFactoryOrchestrator::create(
            ScannerType::Os, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);
        m_packageInsertOrchestration = TFactoryOrchestrator::create(
            ScannerType::PackageInsert, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);
        m_packageDeleteOrchestration = TFactoryOrchestrator::create(
            ScannerType::PackageDelete, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);
        m_integrityClearOrchestration = TFactoryOrchestrator::create(
            ScannerType::IntegrityClear, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);
        m_reScanAllOrchestration = TFactoryOrchestrator::create(
            ScannerType::ReScanAllAgents, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);
        m_reScanOrchestration = TFactoryOrchestrator::create(
            ScannerType::ReScanSingleAgent, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);
        m_deleteAgentScanOrchestration = TFactoryOrchestrator::create(ScannerType::CleanupSingleAgentData,
                                                                      databaseFeedManager,
                                                                      indexerConnector,
                                                                      inventoryDatabase,
                                                                      reportDispatcher);
        m_cleanUpDataOrchestration = TFactoryOrchestrator::create(ScannerType::CleanupAllAgentData,
                                                                  databaseFeedManager,
                                                                  indexerConnector,
                                                                  inventoryDatabase,
                                                                  reportDispatcher);
        m_inventorySyncOrchestration = TFactoryOrchestrator::create(ScannerType::GlobalSyncInventory,
                                                                    databaseFeedManager,
                                                                    indexerConnector,
                                                                    inventoryDatabase,
                                                                    reportDispatcher);
        // coverity[copy_constructor_call]
        m_hotfixInsertOrchestration = TFactoryOrchestrator::create(
            ScannerType::HotfixInsert, databaseFeedManager, indexerConnector, inventoryDatabase, reportDispatcher);

        // Define the maximum size for the hostname
        constexpr auto MAX_HOSTNAME_SIZE = 256;
        char managerNameRaw[MAX_HOSTNAME_SIZE] = {0};

        // Get the hostname and store it in the managerName string
        TOSPrimitives::gethostname(managerNameRaw, MAX_HOSTNAME_SIZE);
        // POSIX leaves it unspecified whether a truncated hostname is null-terminated, so force termination.
        managerNameRaw[MAX_HOSTNAME_SIZE - 1] = '\0';

        GlobalData::instance().managerName(managerNameRaw);

        initEventDelayedDispatcher();
    }
    ~TScanOrchestrator() = default;
    // LCOV_EXCL_STOP

    /**
     * @brief Start the delayed event dispatcher.
     *
     * @details Each delayed element is re-processed through processEvent(element, true).
     * - WdbDataException: the agent's queue is postponed by DelayedPostponeSeconds and the error is rethrown as
     *   std::runtime_error.
     * - AgentReScanListException: a re-scan event is queued and postponed for every listed agent.
     * - Other errors are logged.
     */
    void initEventDelayedDispatcher()
    {
        m_eventDelayedDispatcher = std::make_shared<EventDelayedDispatcher>(
            // coverity[copy_constructor_call]
            [this](rocksdb::PinnableSlice& element)
            {
                // Best-effort extraction of the raw payload, used only for error logging.
                const auto parseEventMessage = [](const rocksdb::PinnableSlice& elementToParse) -> std::string
                {
                    if (const auto eventMessageBuffer = GetMessageBuffer(elementToParse.data());
                        eventMessageBuffer && eventMessageBuffer->data())
                    {
                        return std::string(eventMessageBuffer->data()->begin(), eventMessageBuffer->data()->end());
                    }

                    return "unable to parse";
                };

                try
                {
                    processEvent(element, true);
                }
                catch (const WdbDataException& e)
                {
                    // Use the template parameter so instantiations can override the postpone interval.
                    m_eventDelayedDispatcher->postpone(e.agentId(), std::chrono::seconds(DelayedPostponeSeconds));
                    logDebug2(WM_VULNSCAN_LOGTAG, "Postponed delayed event for agent %s", e.agentId().c_str());
                    // NOTE(review): presumably rethrown so the dispatcher keeps the element for a retry —
                    // confirm against TThreadEventDispatcher.
                    throw std::runtime_error(e.what());
                }
                catch (const AgentReScanListException& e)
                {
                    logDebug2(WM_VULNSCAN_LOGTAG, "AgentReScanListException. Reason: %s", e.what());
                    for (const auto& agentData : e.agentList())
                    {
                        pushReScanToDelayedDispatcher(agentData.id, e.noIndex());
                        m_eventDelayedDispatcher->postpone(agentData.id,
                                                           std::chrono::seconds(DelayedPostponeSeconds));
                    }
                }
                catch (const nlohmann::json::exception& e)
                {
                    logError(WM_VULNSCAN_LOGTAG,
                             "ScanOrchestrator::initEventDelayedDispatcher: json exception (%d) - Event "
                             "message: %s",
                             e.id,
                             parseEventMessage(element).c_str());
                }
                catch (const std::exception& e)
                {
                    logError(WM_VULNSCAN_LOGTAG, "Error processing delayed event: %s.", e.what());
                }
            },
            DELAYED_QUEUE_PATH);
    }

    /**
     * @brief Push an event to the delayed dispatcher.
     * @param element Event to push (raw flatbuffer message).
     * @param agentId Agent ID, used as the delayed queue key.
     */
    void pushEventToDelayedDispatcher(const rocksdb::PinnableSlice& element, const std::string& agentId)
    {
        m_eventDelayedDispatcher->push(agentId, element);
    }

    /**
     * @brief Push a re-scan event for one agent to the delayed dispatcher.
     *
     * @details Builds a JSON "scanAgent" action, wraps it in a MessageBuffer flatbuffer of type BufferType_JSON and
     * queues it under the agent ID.
     *
     * @param agentId Agent ID.
     * @param noIndex Value stored in the message's "no-index" field (presumably true disables indexing of the
     * results).
     */
    void pushReScanToDelayedDispatcher(const std::string& agentId, const bool noIndex)
    {
        nlohmann::json dataValueJson;
        dataValueJson["action"] = "scanAgent";
        dataValueJson["agent_info"]["agent_id"] = agentId;
        dataValueJson["no-index"] = noIndex;

        const std::string dataValue = dataValueJson.dump();
        // Build the payload directly as int8_t to match the flatbuffer [byte] vector. This avoids a
        // reinterpret_cast between unrelated std::vector specializations.
        const std::vector<int8_t> message(dataValue.begin(), dataValue.end());

        flatbuffers::FlatBufferBuilder builder;
        auto object =
            CreateMessageBufferDirect(builder, &message, BufferType::BufferType_JSON, Utils::getSecondsFromEpoch());

        builder.Finish(object);

        m_eventDelayedDispatcher->push(agentId,
                                       {reinterpret_cast<const char*>(builder.GetBufferPointer()), builder.GetSize()});
    }

    /**
     * @brief Process an event.
     *
     * @param input Event to process (serialized MessageBuffer flatbuffer).
     * @param isDelayed Flag to indicate if the event comes from the delayed queue.
     *
     * @throws std::runtime_error If the message has no payload or its type is unknown.
     * @throws nlohmann::json::exception If a JSON payload cannot be parsed.
     */
    void processEvent(const rocksdb::PinnableSlice& input, const bool isDelayed = false) const
    {
        const auto message = GetMessageBuffer(input.data());

        // A message without payload cannot be decoded; fail explicitly instead of dereferencing a null vector.
        if (message->data() == nullptr)
        {
            throw std::runtime_error("Event message without data");
        }

        if (message->type() == BufferType::BufferType_RSync)
        {
            std::variant<const SyscollectorDeltas::Delta*, const Synchronization::SyncMsg*, const nlohmann::json*>
                data = Synchronization::GetSyncMsg(message->data()->data());

            run(data, input, isDelayed);
        }
        else if (message->type() == BufferType::BufferType_DBSync)
        {
            std::variant<const SyscollectorDeltas::Delta*, const Synchronization::SyncMsg*, const nlohmann::json*>
                data = SyscollectorDeltas::GetDelta(message->data()->data());

            run(data, input, isDelayed);
        }
        else if (message->type() == BufferType::BufferType_JSON)
        {
            // jsonData must outlive run(), which receives only a pointer to it.
            auto jsonData =
                nlohmann::json::parse(message->data()->data(), message->data()->data() + message->data()->size());
            std::variant<const SyscollectorDeltas::Delta*, const Synchronization::SyncMsg*, const nlohmann::json*>
                data = &jsonData;

            run(data, input, isDelayed);
        }
        else
        {
            throw std::runtime_error("Unknown event type");
        }
    }

private:
    /**
     * @brief Runs orchestrator, decoding and building context.
     *
     * @param data Data to process.
     * @param rawData Raw data to process (re-queued as-is when the event is postponed).
     * @param isDelayed Flag to indicate if the event is delayed.
     */
    void
    run(std::variant<const SyscollectorDeltas::Delta*, const Synchronization::SyncMsg*, const nlohmann::json*> data,
        const rocksdb::PinnableSlice& rawData,
        bool isDelayed) const
    {
        // Exclusive lock on the shared mutex: only one thread can run an orchestration at a time.
        std::scoped_lock lock(m_mutex);

        auto context = std::make_shared<TScanContext>(data);
        const auto type = context->getType();

        // Postpone the event if:
        // - Wasn't previously postponed
        // - There are delayed events for that agent
        // - The event is not a cleanup or re-scan all agents event
        // Cleanup and re-scan events purge the delayed queue, so they have priority to be processed
        if (!isDelayed && type != ScannerType::CleanupAllAgentData && type != ScannerType::ReScanAllAgents &&
            type != ScannerType::CleanupSingleAgentData && m_eventDelayedDispatcher->size(context->agentId()))
        {
            // Cast explicitly: a scoped enum is not promoted when passed to a %d format argument.
            logDebug2(WM_VULNSCAN_LOGTAG,
                      "Postponing event '%d' for agent '%s'",
                      static_cast<int>(type),
                      context->agentId().data());
            m_eventDelayedDispatcher->push(context->agentId(), rawData);
        }
        else
        {
            switch (type)
            {
                case ScannerType::HotfixInsert:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'HotfixInsert' event for agent '%s'",
                              context->agentId().data());
                    m_hotfixInsertOrchestration->handleRequest(std::move(context));
                    break;
                case ScannerType::PackageInsert:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'PackageInsert' event for agent '%s'",
                              context->agentId().data());
                    m_packageInsertOrchestration->handleRequest(std::move(context));
                    break;
                case ScannerType::PackageDelete:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'PackageDelete' event for agent '%s'",
                              context->agentId().data());
                    m_packageDeleteOrchestration->handleRequest(std::move(context));
                    break;
                case ScannerType::Os:
                    logDebug2(
                        WM_VULNSCAN_LOGTAG, "Processing 'OS scan' event for agent '%s'", context->agentId().data());
                    m_osOrchestration->handleRequest(std::move(context));
                    break;
                case ScannerType::IntegrityClear:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'IntegrityClear' event for agent '%s'",
                              context->agentId().data());
                    m_integrityClearOrchestration->handleRequest(std::move(context));
                    break;
                // LCOV_EXCL_START
                case ScannerType::ReScanAllAgents:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Re-scan all agents event received. Initiating re-scan for all agents");
                    m_reScanAllOrchestration->handleRequest(std::move(context));
                    m_eventDelayedDispatcher->clear();
                    break;
                case ScannerType::ReScanSingleAgent:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'ReScanSingleAgent' event for agent '%s'",
                              context->agentId().data());
                    // Not moved: the context is still needed to clear this agent's delayed queue.
                    m_reScanOrchestration->handleRequest(context);
                    m_eventDelayedDispatcher->clear(context->agentId());
                    break;
                case ScannerType::CleanupAllAgentData:
                    logDebug2(WM_VULNSCAN_LOGTAG, "Clean-up all data event received. Cleaning up data for all agents");
                    m_cleanUpDataOrchestration->handleRequest(std::move(context));
                    m_eventDelayedDispatcher->clear();
                    break;
                case ScannerType::CleanupSingleAgentData:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'CleanupSingleAgentData' event for agent '%s'",
                              context->agentId().data());
                    // Not moved: the context is still needed to clear this agent's delayed queue.
                    m_deleteAgentScanOrchestration->handleRequest(context);
                    m_eventDelayedDispatcher->clear(context->agentId());
                    break;
                case ScannerType::GlobalSyncInventory:
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Processing 'GlobalSyncInventory' event to synchronize inventory across nodes");
                    m_inventorySyncOrchestration->handleRequest(std::move(context));
                    break;
                // LCOV_EXCL_STOP
                default: return;
            }
        }
        logDebug2(WM_VULNSCAN_LOGTAG, "Event type: %d processed", static_cast<int>(type));
    }

    /**
     * @brief Inventory database (opened at INVENTORY_DB_PATH). It is shared by reference with every orchestration
     * chain, so it is declared first and destroyed last.
     */
    std::unique_ptr<Utils::RocksDBWrapper> m_inventoryDatabase;
    /// Orchestration chain for ScannerType::Os events.
    std::shared_ptr<TOrchestrationNode> m_osOrchestration;
    /// Orchestration chain for ScannerType::PackageInsert events.
    std::shared_ptr<TOrchestrationNode> m_packageInsertOrchestration;
    /// Orchestration chain for ScannerType::PackageDelete events.
    std::shared_ptr<TOrchestrationNode> m_packageDeleteOrchestration;
    /// Orchestration chain for ScannerType::HotfixInsert events.
    std::shared_ptr<TOrchestrationNode> m_hotfixInsertOrchestration;
    /// NOTE(review): never initialized or used in this class — candidate for removal.
    std::shared_ptr<TOrchestrationNode> m_hotfixDeleteOrchestration;
    /// Orchestration chain for ScannerType::IntegrityClear events.
    std::shared_ptr<TOrchestrationNode> m_integrityClearOrchestration;
    /// NOTE(review): never initialized or used in this class — candidate for removal.
    std::shared_ptr<TOrchestrationNode> m_fetchAllFromGlobalDbOrchestration;
    /// Orchestration chain for ScannerType::ReScanAllAgents events.
    std::shared_ptr<TOrchestrationNode> m_reScanAllOrchestration;
    /// Orchestration chain for ScannerType::ReScanSingleAgent events.
    std::shared_ptr<TOrchestrationNode> m_reScanOrchestration;
    /// Orchestration chain for ScannerType::CleanupAllAgentData events.
    std::shared_ptr<TOrchestrationNode> m_cleanUpDataOrchestration;
    /// Orchestration chain for ScannerType::CleanupSingleAgentData events.
    std::shared_ptr<TOrchestrationNode> m_deleteAgentScanOrchestration;
    /// Orchestration chain for ScannerType::GlobalSyncInventory events.
    std::shared_ptr<TOrchestrationNode> m_inventorySyncOrchestration;
    /// Externally owned mutex; locked exclusively for each run().
    std::shared_mutex& m_mutex;
    /// Per-agent delayed event queue. Declared last so it is destroyed first, before the members its callback uses.
    std::shared_ptr<EventDelayedDispatcher> m_eventDelayedDispatcher;
};

/**
 * @brief Production scan orchestrator: TScanOrchestrator instantiated with all default template arguments.
 */
using ScanOrchestrator = TScanOrchestrator<>;

#endif // _SCAN_ORCHESTRATOR_HPP
